import gym
import matplotlib.pyplot as plt
from stable_baselines3 import PPO, A2C
from stable_baselines3.common.evaluation import evaluate_policy

# Experiment settings shared by both algorithms so the comparison stays
# like-for-like (same environment, training budget and evaluation length).
ENV_ID = "CartPole-v1"
TOTAL_TIMESTEPS = 100_000
N_EVAL_EPISODES = 20


def _train_and_evaluate(
    algo_cls,
    env,
    total_timesteps=TOTAL_TIMESTEPS,
    n_eval_episodes=N_EVAL_EPISODES,
):
    """Train ``algo_cls`` with an ``"MlpPolicy"`` on ``env``, then evaluate it.

    Args:
        algo_cls: Algorithm class to instantiate (``PPO`` or ``A2C`` here).
        env: Environment used for both training and evaluation.
        total_timesteps: Value forwarded to ``model.learn``.
        n_eval_episodes: Value forwarded to ``evaluate_policy``.

    Returns:
        A ``(model, episode_rewards)`` tuple. ``episode_rewards`` is the first
        element returned by ``evaluate_policy`` when called with
        ``return_episode_rewards=True``; the second element is discarded.
    """
    model = algo_cls("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=total_timesteps)
    episode_rewards, _ = evaluate_policy(
        model,
        env,
        n_eval_episodes=n_eval_episodes,
        return_episode_rewards=True,
    )
    return model, episode_rewards


# Create the environment shared by both training runs and both evaluations.
env = gym.make(ENV_ID)

try:
    # Train and evaluate PPO first, then A2C, on the same environment.
    ppo_model, ppo_rewards = _train_and_evaluate(PPO, env)
    a2c_model, a2c_rewards = _train_and_evaluate(A2C, env)
finally:
    # Release the environment even if training/evaluation raises or is
    # interrupted; plotting below does not need it, so there is no reason to
    # keep it open while the (blocking) plot window is shown.
    env.close()

# Visual comparison of the per-episode evaluation rewards.
plt.figure(figsize=(10, 6))
plt.plot(ppo_rewards, label="PPO")
plt.plot(a2c_rewards, label="A2C")
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.title("PPO vs A2C Evaluation Rewards")
plt.legend()
plt.grid(True)
# Save before show(): the figure is written to disk before the blocking
# interactive window is opened.
plt.savefig("ppo_a2c_rewards_comparison.png")
plt.show()
